Implement Warmup-Stable-Decay (WSD) Learning Rate Schedule #2883
base: main
Conversation
Force-pushed from 0bdb5ea to 295c238.
Force-pushed from 9a359f0 to 76b8800.
gagika left a comment:
Thanks for the PR!
I have a few minor comments.
Force-pushed from 76b8800 to cbc4557.
@gagika Thanks for the review! I've updated the code to use types instead of strings.
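For context, a rough sketch of what such string-backed enums could look like; the member values follow the options quoted in this PR, but the actual definitions in `src/MaxText/configs/types.py` may differ:

```python
import enum

# Hypothetical sketch of string-backed enums replacing raw string options;
# the real definitions live in src/MaxText/configs/types.py and may differ.
class LearningRateScheduleType(str, enum.Enum):
  COSINE = "cosine"
  WSD = "wsd"

class WsdDecayStyle(str, enum.Enum):
  LINEAR = "linear"
  COSINE = "cosine"
```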
Force-pushed from d560850 to acbec6e.
One integration test failed with the error below, but this seems unrelated to my changes.
Seems like a transient issue; I retriggered the tests on your PR.
gagika left a comment:
thanks
A9isha left a comment:
Thank you so much for such a detailed and thoughtful PR. I have just a few small comments. Really appreciate the contribution!
src/MaxText/configs/base.yml (Outdated)
warmup_steps_fraction: 0.1
lr_schedule_type: 'cosine' # Options: 'cosine' or 'wsd'
cosine_learning_rate_final_fraction: 0.1 # Final LR as fraction of peak LR for cosine schedule
wsd_learning_rate_final_fraction: 0.1 # Final LR as fraction of peak LR for WSD schedule
Could we use a single learning_rate_final_fraction instead of having two separate ones for cosine and WSD?
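As a small illustration of this suggestion, both branches could derive their floor from one shared knob; the helper below is purely hypothetical:

```python
# Hypothetical sketch: a single `learning_rate_final_fraction` can feed both
# schedules, since each one only needs a floor learning rate.
def final_lr(peak_lr: float, learning_rate_final_fraction: float) -> float:
  """Floor LR shared by the cosine and WSD decay phases."""
  return peak_lr * learning_rate_final_fraction
```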
src/MaxText/maxtext_utils.py (Outdated)
elif config.wsd_decay_style == types.WsdDecayStyle.COSINE:
  decay_schedule = make_cos_schedule(lr, wsd_final_lr, decay_steps)
else:
  raise ValueError(
If we move these ValueError checks to pyconfig.py, then only the clean logic would stay here.
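A possible shape for that cleaned-up branch, assuming the invalid-value check has already happened during config parsing; the helper name and plain-string comparison below are stand-ins for the PR's actual enum-based code:

```python
import jax.numpy as jnp
import optax

# Hypothetical sketch: with invalid `wsd_decay_style` values rejected during
# config parsing (e.g. in pyconfig.py), the schedule code only dispatches
# between the two known styles and needs no ValueError branch here.
def make_wsd_decay_phase(decay_style: str, peak_lr: float, final_lr: float,
                         decay_steps: int):
  if decay_style == "linear":  # stand-in for types.WsdDecayStyle.LINEAR
    return optax.linear_schedule(peak_lr, final_lr, decay_steps)
  # The only remaining style is cosine (types.WsdDecayStyle.COSINE).
  def cosine_decay(step):
    frac = jnp.minimum(step / decay_steps, 1.0)
    return final_lr + 0.5 * (peak_lr - final_lr) * (1.0 + jnp.cos(jnp.pi * frac))
  return cosine_decay
```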
tests/maxtext_utils_test.py (Outdated)
# Decay phase: peak -> final
lr_mid_decay = schedule_fn(decay_start + decay_steps // 2)
expected_final = learning_rate * wsd_learning_rate_final_fraction
Could we also check that it reaches the expected_final value at step 1000, i.e., after all the decay_steps? Same for the cosine test.
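For illustration, the requested assertion could look roughly like the hypothetical helper below, where schedule_fn and the final fraction come from the existing test setup:

```python
import numpy as np

# Hypothetical helper: assert the schedule has fully annealed at the last
# step (step 1000 in the quoted test), reusable for both decay styles.
def assert_reaches_final(schedule_fn, last_step, peak_lr, final_fraction):
  expected_final = peak_lr * final_fraction
  np.testing.assert_allclose(schedule_fn(last_step), expected_final, rtol=1e-6)
```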
Force-pushed from f9aa5b3 to fa2d6ef.
…le stable and decay phases Signed-off-by: bzantium <[email protected]>
Force-pushed from fa2d6ef to 0508fe3.
Description
This PR implements the Warmup-Stable-Decay (WSD) learning rate schedule as a configurable option alongside the existing Cosine schedule. This allows users to choose between the standard cosine decay and a schedule that maintains a stable peak learning rate for the majority of training before a rapid decay.
Additionally, this implementation introduces a wsd_decay_style parameter, giving users the flexibility to choose the decay profile (linear or cosine) for the final annealing phase.

Details and Context:
- src/MaxText/configs/base.yml: adds lr_schedule_type (options: 'cosine', 'wsd'), wsd_learning_rate_final_fraction, wsd_decay_steps_fraction, and wsd_decay_style, which supports 'linear' (default, standard for WSD) or 'cosine' decay for the final phase.
- src/MaxText/configs/types.py: adds LearningRateScheduleType and WsdDecayStyle Enums, and updates the Optimizer class to include validation for these new fields.
- src/MaxText/maxtext_utils.py: updates create_learning_rate_schedule to switch between Cosine and WSD logic. The WSD schedule follows Linear Warmup -> Constant Stable -> Decay, using optax.linear_schedule and a custom cosine schedule based on wsd_decay_style, and validates that warmup_steps_fraction + wsd_decay_steps_fraction <= 1.0. A sketch of this shape follows below.
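A minimal, self-contained sketch of the Warmup -> Stable -> Decay shape described above, written against optax. The helper name, defaults, and plain-string decay_style argument are illustrative; the MaxText implementation wires these values from the config fields and enums instead:

```python
import jax.numpy as jnp
import optax

def make_wsd_schedule(peak_lr, total_steps, warmup_fraction=0.1,
                      decay_fraction=0.1, final_fraction=0.1,
                      decay_style="linear"):
  """Warmup-Stable-Decay: linear warmup, constant plateau, short final anneal."""
  # Mirrors the PR's validation that warmup and decay fit within training.
  assert warmup_fraction + decay_fraction <= 1.0
  warmup_steps = int(total_steps * warmup_fraction)
  decay_steps = int(total_steps * decay_fraction)
  final_lr = peak_lr * final_fraction

  warmup = optax.linear_schedule(0.0, peak_lr, warmup_steps)
  stable = optax.constant_schedule(peak_lr)
  if decay_style == "linear":
    decay = optax.linear_schedule(peak_lr, final_lr, decay_steps)
  else:  # "cosine"
    decay = lambda step: final_lr + 0.5 * (peak_lr - final_lr) * (
        1.0 + jnp.cos(jnp.pi * jnp.minimum(step / decay_steps, 1.0)))
  # join_schedules switches phases at the boundaries and restarts the step
  # count for each phase.
  return optax.join_schedules(
      schedules=[warmup, stable, decay],
      boundaries=[warmup_steps, total_steps - decay_steps])
```

With peak_lr=3e-4, total_steps=1000, and the defaults above, the rate ramps to 3e-4 over the first 100 steps, holds it until step 900, and then anneals linearly to 3e-5 by step 1000.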
I have added a comprehensive test suite, TestLearningRateSchedules, in tests/maxtext_utils_test.py. It covers both the linear and cosine decay styles and checks that invalid configurations raise a ValueError.

To reproduce/test:
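One assumed way to exercise the new suite is to run pytest on tests/maxtext_utils_test.py and select TestLearningRateSchedules; the exact command may differ. As a further, purely illustrative check of the schedule shape, reusing the make_wsd_schedule sketch from the description above:

```python
# Purely illustrative shape check built on the make_wsd_schedule sketch above;
# the real suite constructs the schedule from the MaxText config instead.
def smoke_test_wsd_shape():
  peak_lr, total_steps = 3e-4, 1000
  fn = make_wsd_schedule(peak_lr, total_steps)  # defaults: 10% warmup, 10% decay
  assert float(fn(0)) == 0.0                                  # warmup starts at zero
  assert abs(float(fn(500)) - peak_lr) < 1e-9                 # plateau holds peak LR
  assert abs(float(fn(total_steps)) - 0.1 * peak_lr) < 1e-9   # fully annealed
```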
Fixes: #2882
Checklist
Before submitting this PR, please make sure (put X in square brackets):
gemini-review label.